# -*- coding: utf-8 -*-
"""
Created on Sat Apr  1 15:09:05 2023

@author: 29672366
"""

import os
import re
from typing import List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torch.nn.functional as F
import numpy as np
from utils.DigitEmbedding import DigitEmbedding
from utils.SelfAttentionEncoder import SelfAttentionEncoder

class QAData(Dataset):
    """Dataset of question/answer pairs read from ``Q:`` / ``A:`` formatted
    text files and encoded into digit tensors.

    Each item is a pair ``(inputs, targets)``: the last ``historynum``
    encoded questions and answers stacked along axis 0 (zero-padded on the
    left while the history is still filling up).
    """

    def __init__(self, data_path: str,
                 tokenizer: BertTokenizer,
                 max_seq_length: int,
                 num_digits: int = 6,
                 historynum: int = 1):
        self.num_digits = num_digits
        # Converts tokenizer ids into per-token digit encodings.
        self.digit_embedding = DigitEmbedding(self.num_digits, max_seq_length)
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.historynum = historynum
        self.data = self.load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        inputs, target = self.data[index]
        return inputs, target

    def preprocess_inputs(self, inputs, max_length):
        """Left-pad with zero sequences (or left-truncate) `inputs` along the
        history axis so that exactly `max_length` sequences remain."""
        seq_length = len(inputs)
        if seq_length < max_length:
            num_to_pad = max_length - seq_length
            # One all-zero (seq_len, num_digits) row per missing history slot.
            padding = [[[0] * self.num_digits] * len(inputs[0])] * num_to_pad
            inputs = np.concatenate((padding, inputs), axis=0)
        else:
            inputs = inputs[-max_length:]
        return inputs

    def _append_window(self, inputs, targets, packdata):
        """Trim the running histories to the last `historynum` entries and
        append the stacked (inputs, targets) window to `packdata`.

        Mutates `inputs`/`targets` in place so the caller keeps the trimmed
        histories for the next window.
        """
        if len(inputs) >= self.historynum:
            # Keep only the most recent `historynum` entries of each history.
            del inputs[:len(inputs) - self.historynum]
            # BUG FIX: targets used to be sliced from `inputs`, so every
            # window's targets silently duplicated its inputs.
            del targets[:len(targets) - self.historynum]
            tensor_inputs = np.stack(inputs, axis=0)
            tensor_targets = np.stack(targets, axis=0)
        else:
            # BUG FIX: this zero-padded window was computed and then thrown
            # away, so pairs seen before the history filled up were lost.
            tensor_inputs = self.preprocess_inputs(np.stack(inputs, axis=0), self.historynum)
            tensor_targets = self.preprocess_inputs(np.stack(targets, axis=0), self.historynum)
        packdata.append((tensor_inputs, tensor_targets))

    def _flush_pair(self, question, answer, inputs, targets, packdata):
        """Encode one finished (question, answer-lines) pair and pack it."""
        # Join the multi-line answer into a single string before encoding.
        target = self.encode_text(" ".join(answer))
        inputs.append(self.encode_text(question))
        targets.append(target)
        self._append_window(inputs, targets, packdata)

    def load_data(self, data_path: str) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Parse every ``*.txt`` file under `data_path`.

        Files hold ``Q: ...`` / ``A: ...`` lines; any other non-matching line
        is treated as a continuation of the latest answer. Returns the list
        of encoded (inputs, targets) windows across all files.
        """
        packdata = []
        for filename in os.listdir(data_path):
            if not filename.endswith(".txt"):
                continue
            inputs = []
            targets = []
            with open(os.path.join(data_path, filename), "r", encoding="utf-8") as f:
                question = None
                answer = []
                for line in f:
                    line = line.strip()
                    if line.startswith("Q:"):
                        # A new question closes the previous pair, if any.
                        if question is not None and answer:
                            self._flush_pair(question, answer, inputs, targets, packdata)
                            answer = []
                        question = line.replace("Q:", "").strip()
                    elif line.startswith("A:"):
                        answer.append(line.replace("A:", "").strip())
                    elif answer:
                        # Plain line: continuation of the latest answer.
                        answer[-1] += " " + line
                # Flush the final pending pair of the file.
                if question is not None and answer:
                    self._flush_pair(question, answer, inputs, targets, packdata)
        return packdata

    def encode_text(self, text: str) -> torch.Tensor:
        """Tokenize `text` (no special tokens), pad/truncate the ids to
        `max_seq_length`, digit-encode them, and return a 2-D tensor whose
        last dimension is `num_digits` (row count depends on
        `DigitEmbedding.encode` — presumably `max_seq_length`; confirm there).
        """
        encoded_text = self.tokenizer.encode(text, add_special_tokens=False)
        # BUG FIX: truncate over-long texts; ragged lengths would otherwise
        # break np.stack over the history window.
        encoded_text = encoded_text[:self.max_seq_length]
        padding = [0] * (self.max_seq_length - len(encoded_text))
        encoded_text += padding
        encoded_text = self.digit_embedding.encode(encoded_text)
        return torch.tensor(encoded_text).view(-1, self.num_digits)

class QAModel(torch.nn.Module):
    """Thin wrapper around ``SelfAttentionEncoder`` adding save/load helpers."""

    def __init__(self, max_seq_length: int,
                 input_max_num_sequences: int, num_digits: int):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.input_max_num_sequences = input_max_num_sequences
        # The wrapped encoder performs the actual computation.
        self.model = SelfAttentionEncoder(
            max_seq_length=self.max_seq_length,
            input_max_num_sequences=self.input_max_num_sequences,
            num_digits=num_digits,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Delegate the forward pass to the wrapped encoder."""
        outputs = self.model(inputs)
        return outputs

    def save(self, path):
        """Serialize the model weights to `path`."""
        state = self.state_dict()
        torch.save(state, path)

    def load(self, path):
        """Restore model weights previously written by :meth:`save`."""
        state = torch.load(path)
        self.load_state_dict(state)
   
